# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import time

import cv2
import h5py
import numpy as np

# Directory that holds the COFW .mat files; the converted images and
# annotation files are written into sub-directories of it.
dataset_dir = 'data/cofw/'
mat_files = ['COFW_train_color.mat', 'COFW_test_color.mat']

# Output locations for the extracted images and the generated JSON files.
image_root, annotation_root = (
    os.path.join(dataset_dir, sub_dir)
    for sub_dir in ('images/', 'annotations/'))

# Create both output directories up front; existing ones are left as is.
for _out_dir in (image_root, annotation_root):
    os.makedirs(_out_dir, exist_ok=True)

# Running image/annotation id. It is shared by the train and test splits so
# that ids and image file names stay unique inside the common image folder.
cnt = 0
for mat_file in mat_files:
    # The train and test files store the same kind of data under keys that
    # differ only in their suffix ('Tr' vs. 'T').
    if 'train' in mat_file:
        img_key, pts_key, bbox_key = 'IsTr', 'phisTr', 'bboxesTr'
        json_file = 'cofw_train.json'
    else:
        img_key, pts_key, bbox_key = 'IsT', 'phisT', 'bboxesT'
        json_file = 'cofw_test.json'

    images = []
    annotations = []

    # Open the .mat (HDF5) file in a context manager so the handle is closed
    # as soon as this split has been converted, even if an error occurs.
    with h5py.File(os.path.join(dataset_dir, mat_file), 'r') as mat:
        imgs = mat[img_key]
        pts = mat[pts_key]
        bboxes = mat[bbox_key]

        # One column of `pts` per sample.
        num = pts.shape[1]
        for idx in range(num):
            cnt += 1
            # `imgs` holds references into the file; dereference to get the
            # pixel data and reverse the axis order.
            img = np.array(mat[imgs[0, idx]]).transpose()
            # Each column is split into 3 equal rows and transposed, giving
            # one 3-value row per keypoint.
            keypoints = pts[:, idx].reshape(3, -1).transpose()
            # 2 for valid and 1 for occlusion
            keypoints[:, 2] = 2 - keypoints[:, 2]
            # matlab 1-index to python 0-index
            keypoints[:, :2] -= 1
            bbox = bboxes[:, idx]

            # check nonnegativity
            bbox[bbox < 0] = 0
            keypoints[keypoints < 0] = 0

            file_name = f'{str(cnt).zfill(6)}.jpg'
            image = {
                'id': cnt,
                'file_name': file_name,
                'height': img.shape[0],
                'width': img.shape[1],
            }
            # OpenCV writes channels in BGR order, so convert from RGB first.
            cv2.imwrite(
                os.path.join(image_root, file_name),
                cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            images.append(image)

            bbox_list = bbox.tolist()
            anno = {
                'keypoints': keypoints.reshape(-1).tolist(),
                'image_id': cnt,
                'id': cnt,
                # all keypoints are labelled
                'num_keypoints': len(keypoints),
                'bbox': bbox_list,
                'iscrowd': 0,
                # product of the 3rd and 4th bbox entries
                'area': bbox_list[2] * bbox_list[3],
                'category_id': 1,
            }
            annotations.append(anno)

    # Take the timestamp once so 'year' and 'date_created' always agree.
    now = time.localtime()
    cocotype = {
        'info': {
            'description': 'COFW Generated by MMPose Team',
            'version': '1.0',
            'year': time.strftime('%Y', now),
            'date_created': time.strftime('%Y/%m/%d', now),
        },
        'images': images,
        'annotations': annotations,
        'categories': [{
            'supercategory': 'person',
            'id': 1,
            'name': 'face',
            'keypoints': [],
            'skeleton': []
        }],
    }

    ann_path = os.path.join(annotation_root, json_file)
    # Use a context manager so the JSON file is flushed and closed before
    # completion is reported.
    with open(ann_path, 'w') as f:
        json.dump(cocotype, f)
    print(f'done {ann_path}')
